{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## First Neural Network: Image Classification \n",
    "\n",
    "Objectives:\n",
    "- Train a minimal image classifier on [MNIST](https://paperswithcode.com/dataset/mnist) using PyTorch\n",
    "- Use PyTorch and torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The usual imports\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load the data\n",
    "\n",
    "class ReshapeTransform:\n",
    "    \"\"\"Callable transform that reshapes its input to a fixed target size.\"\"\"\n",
    "\n",
    "    def __init__(self, new_size):\n",
    "        # Size handed to torch.reshape on every call; (-1,) flattens to 1-D.\n",
    "        self.new_size = new_size\n",
    "\n",
    "    def __call__(self, img):\n",
    "        \"\"\"Return `img` reshaped to `self.new_size` via torch.reshape.\"\"\"\n",
    "        target_size = self.new_size\n",
    "        return torch.reshape(img, target_size)\n",
    "\n",
    "# Per-image pipeline: to tensor -> float32 -> flatten to a 1-D vector\n",
    "transformations = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.ConvertImageDtype(torch.float32),\n",
    "    ReshapeTransform((-1,)),\n",
    "])\n",
    "\n",
    "# Training split (download=True fetches it into ./data if it is missing)\n",
    "trainset = torchvision.datasets.MNIST(\n",
    "    root='./data',\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=transformations,\n",
    ")\n",
    "\n",
    "# Test split, same root directory and same preprocessing\n",
    "testset = torchvision.datasets.MNIST(\n",
    "    root='./data',\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=transformations,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([60000, 28, 28]), torch.Size([10000, 28, 28]))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shapes of the underlying `.data` tensors of the train and test splits\n",
    "\n",
    "tuple(split.data.shape for split in (trainset, testset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mini-batch loaders for both splits\n",
    "\n",
    "BATCH_SIZE = 128\n",
    "\n",
    "# Training batches: shuffled\n",
    "train_dataloader = torch.utils.data.DataLoader(\n",
    "    trainset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=True,\n",
    "    num_workers=0,\n",
    ")\n",
    "\n",
    "# Test batches: dataset order is kept\n",
    "test_dataloader = torch.utils.data.DataLoader(\n",
    "    testset,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    shuffle=False,\n",
    "    num_workers=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model\n",
    "\n",
    "# Seed the global RNG right before building the model so that weight\n",
    "# initialisation (and, by default, the shuffle order of DataLoaders iterated\n",
    "# afterwards) is identical on every Restart & Run All.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Two-layer MLP: 784 flattened pixels -> 512 hidden units (ReLU) -> 10 class scores\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(784, 512),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(512, 10),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Objective and optimizer used by the training loop\n",
    "\n",
    "# Cross-entropy computed on the model's raw class scores\n",
    "loss = nn.CrossEntropyLoss()\n",
    "\n",
    "# RMSprop over every model parameter, with the library's default settings\n",
    "trainer = torch.optim.RMSprop(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_accuracy(output, target, batch_size=None):\n",
    "    \"\"\"Return the top-1 accuracy of one batch as a percentage (Python float).\n",
    "\n",
    "    Args:\n",
    "        output: score tensor; the predicted class is the argmax along dim 1.\n",
    "        target: tensor of true class indices, one per row of `output`.\n",
    "        batch_size: denominator of the accuracy. Defaults to the number of\n",
    "            elements in `target`, which stays correct when the final batch\n",
    "            of an epoch is smaller than the configured batch size.\n",
    "    \"\"\"\n",
    "    if batch_size is None:\n",
    "        batch_size = target.numel()\n",
    "    # Predicted class per row, shaped like `target` so the comparison is element-wise.\n",
    "    predictions = output.argmax(dim=1).view(target.size())\n",
    "    corrects = (predictions == target).sum()\n",
    "    accuracy = 100.0 * corrects / batch_size\n",
    "    return accuracy.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 | Train loss: 0.9943 | Train Accuracy: 91.7344\n",
      "Epoch: 2 | Train loss: 0.1334 | Train Accuracy: 95.9422\n",
      "Epoch: 3 | Train loss: 0.1030 | Train Accuracy: 96.8767\n",
      "Epoch: 4 | Train loss: 0.0845 | Train Accuracy: 97.4997\n",
      "Epoch: 5 | Train loss: 0.0735 | Train Accuracy: 97.8811\n"
     ]
    }
   ],
   "source": [
    "# train\n",
    "\n",
    "NUM_EPOCHS = 5\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    # Sample-weighted sums, so a short final batch counts in proportion to its size.\n",
    "    train_acc = 0.0\n",
    "    train_running_loss = 0.0\n",
    "    num_samples = 0\n",
    "\n",
    "    model.train()\n",
    "    for X, y in train_dataloader:\n",
    "        n = y.size(0)  # actual batch size; the last batch can be smaller than BATCH_SIZE\n",
    "\n",
    "        # Clear gradients up front so nothing left over from an interrupted\n",
    "        # earlier run of this cell leaks into the first update.\n",
    "        trainer.zero_grad()\n",
    "        output = model(X)\n",
    "        l = loss(output, y)\n",
    "\n",
    "        # update the parameters\n",
    "        l.backward()\n",
    "        trainer.step()\n",
    "\n",
    "        # gather metrics: the loss and the accuracy are per-batch means, so\n",
    "        # scale them by n to accumulate per-sample sums\n",
    "        train_acc += get_accuracy(output, y, n) * n\n",
    "        train_running_loss += l.item() * n\n",
    "        num_samples += n\n",
    "\n",
    "    print('Epoch: %d | Train loss: %.4f | Train Accuracy: %.4f'\n",
    "          % (epoch + 1, train_running_loss / num_samples, train_acc / num_samples))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Other things to try\n",
    "\n",
    "- Evaluate on test set\n",
    "- Plot loss curve\n",
    "- Add more layers to the model"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.13 ('play')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "cf9800998463bc980d70cdbacff0c7e9a10687346dc898569e92f016d6e252c9"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
